# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
from functools import partial
from typing import Dict, Iterator, Literal, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
import ott.utils as ott_utils

# Center-layout names used as the type hint for the Gaussian-mixture samplers.
Name_t = Literal["simple", "circle", "square_five", "square_four"]

# Public API of this module; keep in sync when adding public definitions.
__all__ = [
    "create_gaussian_mixture_samplers",
    "create_lagrangian_ds",
    "create_sphere_ds",
    "Dataset",
    "Gaussian",
    "GaussianMixture",
    "UniformLineDataset",
]


class Dataset(NamedTuple):
  r"""Pair of iterators over the source and target measures.

  Args:
    source_iter: iterator producing batches from the source measure.
    target_iter: iterator producing batches from the target measure.
  """
  # NOTE(review): the annotations say each item is an array, but the
  # ``GaussianMixture`` iterators that ``create_gaussian_mixture_samplers``
  # stores here yield dicts with "src_lin"/"tgt_lin" entries -- confirm the
  # intended item type before relying on it.
  source_iter: Iterator[jnp.ndarray]
  target_iter: Iterator[jnp.ndarray]


@dataclasses.dataclass
class GaussianMixture:
  """Paired source/target samplers drawing from two fixed 2-D Gaussian mixtures.

  Iterating yields an infinite stream of dicts ``{"src_lin": x, "tgt_lin": y}``
  where ``x`` and ``y`` both have shape ``(batch_size, 2)``. Every sample picks
  one component mean uniformly at random, multiplies it by ``scale`` and adds
  isotropic Gaussian noise with standard deviation ``std``. Source samples use
  the layout named by ``source_centers``; target samples use the layout named
  by ``target_centers``.

  Available center layouts:

    - ``simple`` - a single center at the origin,
    - ``circle`` - eight centers on the unit circle,
    - ``square_five`` - the corners of the square ``[-1, 1]^2`` plus the origin,
    - ``square_four`` - the four points at distance 1 on the coordinate axes,
    - ``pi0`` - the four points at distance 4 on the coordinate axes,
    - ``pi1`` - eight centers on a circle of radius ~16.

  Args:
    name: retained for backward compatibility only; it does NOT select the
      centers (``source_centers`` and ``target_centers`` do).
    batch_size: number of points in each source and target batch.
    rng: PRNG key the stream starts from; every call to ``iter()`` restarts
      from this key.
    scale: factor multiplying the component means.
    std: standard deviation of the noise added around each mean.
    source_centers: layout used for ``"src_lin"`` samples (default ``"pi0"``).
    target_centers: layout used for ``"tgt_lin"`` samples (default ``"pi1"``).

  Raises:
    ValueError: if ``source_centers`` or ``target_centers`` is not a known
      layout.
  """
  name: Name_t
  batch_size: int
  rng: jax.Array
  scale: float = 1.0
  std: float = 1.0
  source_centers: str = "pi0"
  target_centers: str = "pi1"

  def __post_init__(self) -> None:
    # Built per instance so every instance owns fresh (mutable) numpy arrays.
    gaussian_centers = {
        "simple":
            np.array([[0, 0]]),
        "circle":
            np.array([
                (1, 0),
                (-1, 0),
                (0, 1),
                (0, -1),
                (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
                (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
                (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
            ]),
        "square_five":
            np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]]),
        "square_four":
            np.array([[1, 0], [0, 1], [-1, 0], [0, -1]]),
        "pi0":
            np.array([[4., 0.], [0., 4.], [-4., 0.], [0., -4.]]),
        # 11.31 is a rounded 16 / sqrt(2), so these centers lie approximately
        # on the radius-16 circle.
        "pi1":
            np.array([
                [16, 0], [11.31, 11.31], [0, 16], [-11.31, 11.31],
                [-16, 0], [-11.31, -11.31], [0, -16], [11.31, -11.31],
            ]),
    }
    unknown = [
        layout for layout in (self.source_centers, self.target_centers)
        if layout not in gaussian_centers
    ]
    if unknown:
      raise ValueError(
          f"Unknown center layout(s) {unknown}; expected one of "
          f"{sorted(gaussian_centers)}."
      )
    self.centers_src = gaussian_centers[self.source_centers]
    self.centers_tgt = gaussian_centers[self.target_centers]

  def __iter__(self) -> Iterator[Dict[str, jax.Array]]:
    """Return an infinite generator of paired source/target batches.

    Returns:
      A generator yielding ``{"src_lin": ..., "tgt_lin": ...}`` dicts.
    """
    return self._create_sample_generators()

  def _sample_mixture(
      self, rng_choice: jax.Array, rng_noise: jax.Array, centers: np.ndarray
  ) -> jax.Array:
    """Draw one ``(batch_size, 2)`` batch from the mixture with ``centers``."""
    means = jax.random.choice(rng_choice, centers, (self.batch_size,))
    noise = jax.random.normal(rng_noise, (self.batch_size, 2))
    return self.scale * means + self.std * noise

  def _create_sample_generators(self) -> Iterator[Dict[str, jax.Array]]:
    rng = self.rng
    while True:
      # One 3-way split per side: two keys are consumed, the third carries on.
      rng_choice, rng_noise, rng = jax.random.split(rng, 3)
      source_samples = self._sample_mixture(
          rng_choice, rng_noise, self.centers_src
      )

      rng_choice, rng_noise, rng = jax.random.split(rng, 3)
      target_samples = self._sample_mixture(
          rng_choice, rng_noise, self.centers_tgt
      )

      yield {"src_lin": source_samples, "tgt_lin": target_samples}


def create_gaussian_mixture_samplers(
    name_source: Name_t,
    name_target: Name_t,
    train_batch_size: int = 2048,
    valid_batch_size: int = 2048,
    rng: Optional[jax.Array] = None,
) -> Tuple[Dataset, Dataset, int]:
  """Build training and validation samplers backed by ``GaussianMixture``.

  ``rng`` is first normalized with ``ott_utils.default_prng_key`` and then
  split into four independent keys, one per iterator (train source, train
  target, validation source, validation target).

  Args:
    name_source: name passed to the source ``GaussianMixture`` samplers.
    name_target: name passed to the target ``GaussianMixture`` samplers.
    train_batch_size: batch size of both training iterators.
    valid_batch_size: batch size of both validation iterators.
    rng: initial PRNG key, or ``None``.

  Returns:
    A ``(train_dataset, valid_dataset, dim_data)`` triple; ``dim_data`` is 2.
  """
  base_key = ott_utils.default_prng_key(rng)
  train_src_key, train_tgt_key, valid_src_key, valid_tgt_key = (
      jax.random.split(base_key, 4)
  )

  def make_iterator(name, batch_size, sampler_key):
    # Each iterator owns its own mixture instance and PRNG stream.
    return iter(GaussianMixture(name, batch_size=batch_size, rng=sampler_key))

  train_dataset = Dataset(
      source_iter=make_iterator(name_source, train_batch_size, train_src_key),
      target_iter=make_iterator(name_target, train_batch_size, train_tgt_key),
  )
  valid_dataset = Dataset(
      source_iter=make_iterator(name_source, valid_batch_size, valid_src_key),
      target_iter=make_iterator(name_target, valid_batch_size, valid_tgt_key),
  )
  return train_dataset, valid_dataset, 2

class UniformLineDataset:
  """Infinite stream of paired batches drawn uniformly from two thin strips.

  Source points are uniform on ``[-1.25, -1.0) x [-1.0, 1.0)`` and target
  points are uniform on ``[1.0, 1.25) x [-1.0, 1.0)``. Each batch is a dict
  ``{"src_lin": x_0, "tgt_lin": x_1}`` of ``(size, 2)`` arrays.

  Every call to ``iter()`` restarts from ``jax.random.PRNGKey(seed)``, so two
  iterators over the same instance produce the same batches.

  Args:
    size: number of points in each source and target batch.
    seed: integer seed of the key each iterator starts from. Defaults to 42.
  """

  def __init__(self, size: int, seed: int = 42):
    self.size = size
    self.seed = seed

  def __iter__(self) -> Iterator[Dict[str, jax.Array]]:
    rng = jax.random.PRNGKey(self.seed)
    while True:
      rng, sample_key = jax.random.split(rng, 2)
      yield UniformLineDataset._sample(sample_key, self.size)

  @staticmethod
  @partial(jax.jit, static_argnums=1)
  def _sample(key: jax.Array, batch_size: int) -> Dict[str, jax.Array]:
    """Draw one batch; ``batch_size`` is static because it fixes the shapes."""

    def strip(k_x, k_y, x_min, x_max):
      # x uniform in [x_min, x_max), y uniform in [-1, 1).
      xs = jax.random.uniform(k_x, (batch_size, 1), minval=x_min, maxval=x_max)
      ys = jax.random.uniform(k_y, (batch_size, 1), minval=-1.0, maxval=1.0)
      return jnp.concatenate([xs, ys], axis=1)

    # Each side uses a 3-way split. The split counts are kept as-is because
    # changing them may change the generated values.
    k1, k2, key = jax.random.split(key, 3)
    x_0 = strip(k1, k2, -1.25, -1.0)

    k1, k2, _ = jax.random.split(key, 3)
    x_1 = strip(k1, k2, 1.0, 1.25)

    return {"src_lin": x_0, "tgt_lin": x_1}

@dataclasses.dataclass
class Gaussian:
  """Paired source/target samplers drawing from two isotropic 2-D Gaussians.

  Each batch is ``mean + var * z`` with ``z`` standard normal of shape
  ``(batch_size, 2)``. Because ``*_var`` multiplies the noise directly, it
  acts as a standard deviation rather than a variance.

  Args:
    source_mean: source mean; a scalar or an array broadcastable to
      ``(batch_size, 2)``.
    source_var: scale multiplying the source noise (used as a std).
    target_mean: target mean; same conventions as ``source_mean``.
    target_var: scale multiplying the target noise (used as a std).
    batch_size: number of points in each source and target batch.
    init_key: PRNG key the stream starts from; every ``iter()`` restarts
      from this key.
  """
  source_mean: Union[float, jax.Array]
  source_var: float

  target_mean: Union[float, jax.Array]
  target_var: float

  batch_size: int
  init_key: jax.Array

  def __iter__(self) -> Iterator[Dict[str, jax.Array]]:
    """Return an infinite generator of paired Gaussian batches.

    Returns:
      A generator yielding ``{"src_lin": ..., "tgt_lin": ...}`` dicts.
    """
    return self._create_sample_generators()

  def _create_sample_generators(self) -> Iterator[Dict[str, jax.Array]]:
    key = self.init_key
    shape = (self.batch_size, 2)
    while True:
      # Two keys are consumed per batch; the third seeds the next batch.
      key1, key2, key = jax.random.split(key, 3)
      source_samples = (
          self.source_mean + self.source_var * jax.random.normal(key1, shape)
      )
      target_samples = (
          self.target_mean + self.target_var * jax.random.normal(key2, shape)
      )
      yield {"src_lin": source_samples, "tgt_lin": target_samples}

            
def create_lagrangian_ds(geometry_str: str, batch_size: int, key):
  """Return the paired source/target sampler for a named geometry.

  * ``"babymaze"``, ``"box"``, ``"slit"``: a ``UniformLineDataset``
    (``key`` is not passed to it).
  * ``"gmm"``: a ``GaussianMixture`` seeded with ``key``.
  * ``"vneck"``, ``"pipe"``, ``"stunnel"``: a ``Gaussian`` with
    geometry-specific means and noise scale, seeded with ``key``.

  Args:
    geometry_str: name of the geometry (see above).
    batch_size: number of points per source/target batch.
    key: PRNG key for the ``"gmm"`` and Gaussian geometries.

  Returns:
    An iterable sampler object.

  Raises:
    ValueError: if ``geometry_str`` is not one of the names above.
  """
  if geometry_str in ("babymaze", "box", "slit"):
    return UniformLineDataset(size=batch_size)

  if geometry_str == "gmm":
    return GaussianMixture(
        "", batch_size=batch_size, rng=key, scale=1.0, std=1.0
    )

  # geometry -> (noise scale, source mean, target mean)
  gaussian_geometries = {
      "vneck": (0.2, [-7, 0.0], [7, 0.0]),
      "pipe": (0.1, [-1.0, 0.0], [1.0, 0.0]),
      "stunnel": (0.5, [-11.0, -1.0], [11.0, 1.0]),
  }
  if geometry_str not in gaussian_geometries:
    known = sorted(["babymaze", "box", "slit", "gmm", *gaussian_geometries])
    raise ValueError(
        f"Unknown geometry {geometry_str!r}; expected one of {known}."
    )

  variance, source_mean, target_mean = gaussian_geometries[geometry_str]
  return Gaussian(
      source_mean=jnp.array(source_mean),
      source_var=variance,
      target_mean=jnp.array(target_mean),
      target_var=variance,
      batch_size=batch_size,
      init_key=key,
  )


def create_sphere_ds(dim: int, sigma: float, batch_size: int, key):
  """Create an infinite stream of paired point clouds on the unit sphere.

  Source points are ``e_2 + sigma * z`` and target points
  ``-e_2 + sigma * z``, where ``e_2`` is the unit vector along axis 2 and
  ``z`` is standard normal. Each point is rescaled to unit Euclidean norm, so
  both clouds lie on the unit sphere in ``R^dim`` around opposite poles.

  Args:
    dim: ambient dimension; must be at least 3 because the poles sit on
      axis 2.
    sigma: standard deviation of the perturbation before normalization.
    batch_size: number of points in each source and target batch.
    key: PRNG key the stream starts from.

  Returns:
    A generator yielding dicts ``{"src_lin": ..., "tgt_lin": ...}`` whose
    values have shape ``(batch_size, dim)``.

  Raises:
    ValueError: if ``dim < 3``. The check runs when this function is called,
      not on the first ``next()``.
  """
  if dim < 3:
    raise ValueError(
        "create_sphere_ds requires dim >= 3 (the poles lie on axis 2); "
        f"got dim={dim}."
    )
  return _sphere_sample_stream(dim, sigma, batch_size, key)


def _sphere_sample_stream(dim: int, sigma: float, batch_size: int, key):
  """Yield normalized Gaussian batches around the +/- poles of axis 2 forever."""
  # A single pole vector broadcasts over the batch.
  north_pole = jnp.zeros((dim,)).at[2].set(1.0)
  south_pole = -north_pole

  def to_unit_sphere(points):
    return points / jnp.linalg.norm(points, axis=-1, keepdims=True)

  while True:
    key, src_key, trg_key = jax.random.split(key, 3)
    src_sample = to_unit_sphere(
        jax.random.normal(src_key, (batch_size, dim)) * sigma + north_pole
    )
    trg_sample = to_unit_sphere(
        jax.random.normal(trg_key, (batch_size, dim)) * sigma + south_pole
    )
    yield {"src_lin": src_sample, "tgt_lin": trg_sample}
